import os
import json
import torch
import tqdm
import hydra
import hydra.utils as hu
from transformers import set_seed
from torch.utils.data import DataLoader

from src.utils.collators import DataCollatorWithPaddingAndCuda
from src.models.biencoder import BiEncoder

# Kernelized submodular selectors called in main()'s selection loop
from kernel_module_gpu import submodular_kernel_select_gpu, submodular_poly_kernel_select_gpu


def _encode_dataset(model, loader, desc, **encode_kwargs):
    """Encode every batch from *loader* and gather embeddings plus metadata.

    Args:
        model: Object exposing ``encode(input_ids, attention_mask, **kwargs)``.
        loader: Iterable of batches supporting ``batch["input_ids"]``,
            ``batch["attention_mask"]`` and ``batch.get("metadata")``.
        desc: Label shown on the progress bar.
        **encode_kwargs: Extra keyword arguments forwarded to ``model.encode``.

    Returns:
        A ``(embeddings, metadata)`` tuple: the per-batch embeddings
        concatenated along dim 0, and a flat list of the metadata entries
        found in the batches (empty if no batch carried any).

    Raises:
        RuntimeError: If *loader* yields no batches.
    """
    embed_chunks = []
    metadata = []
    with torch.no_grad():
        for batch in tqdm.tqdm(loader, desc=desc):
            embed_chunks.append(
                model.encode(
                    batch["input_ids"],
                    batch["attention_mask"],
                    **encode_kwargs
                )
            )
            batch_meta = batch.get("metadata")
            if batch_meta is not None:
                # The metadata value may be a wrapper exposing `.data` or a
                # plain sequence; accept both instead of crashing on `.data`.
                metadata.extend(getattr(batch_meta, "data", batch_meta))

    if not embed_chunks:
        # torch.cat would fail on an empty list with a less helpful message.
        raise RuntimeError(f"{desc}: the data loader produced no batches")
    return torch.cat(embed_chunks, dim=0), metadata


def _as_plain_list(selected):
    """Return *selected* as a plain list so it can be JSON-serialized.

    Objects exposing ``tolist()`` (e.g. tensors or arrays) are converted with
    it; any other iterable is copied into a list.
    """
    tolist = getattr(selected, "tolist", None)
    return tolist() if callable(tolist) else list(selected)


@hydra.main(config_path="configs", config_name="kernel_retriever")
def main(cfg):
    """Select contexts for every query with a kernelized submodular selector.

    Encodes the index set and the query set with a ``BiEncoder``, runs the
    selector once per query against all index embeddings, and writes the
    augmented query entries to ``cfg.output_file`` as JSON.

    Config keys read: ``seed`` (optional, default 42), ``model_config``,
    ``pretrained_model_path``, ``index_reader``, ``dataset_reader``,
    ``batch_size``, ``run_for_n_samples``, ``use_polynomial_kernel``,
    ``num_ice``, ``lambd``, ``output_file`` and, for the non-polynomial
    selector only, ``beta`` and ``lengthscale``.

    Raises:
        RuntimeError: If the index or query loader yields no batches.
        ValueError: If the number of query metadata entries differs from the
            number of query embeddings.
    """
    # reproducibility + cudnn tuning
    set_seed(getattr(cfg, 'seed', 42))
    torch.backends.cudnn.benchmark = True

    # --- Model setup ---
    model_cfg = hu.instantiate(cfg.model_config)
    if cfg.pretrained_model_path:  # value 0 is False
        print(f"Loading model from: {cfg.pretrained_model_path}")
        model = BiEncoder.from_pretrained(cfg.pretrained_model_path, config=model_cfg)
    else:
        model = BiEncoder(model_cfg)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()

    # --- Index embeddings (stay on `device`) ---
    index_reader = hu.instantiate(cfg.index_reader)
    collator = DataCollatorWithPaddingAndCuda(
        tokenizer=index_reader.tokenizer,
        device=device
    )
    index_loader = DataLoader(
        index_reader,
        batch_size=cfg.batch_size,
        collate_fn=collator
    )
    # Index metadata is not needed downstream, so it is discarded.
    demo_embeds, _ = _encode_dataset(
        model, index_loader, "Encoding index passages", encode_ctx=True
    )

    # --- Query embeddings (stay on `device`) ---
    query_reader = hu.instantiate(cfg.dataset_reader)
    query_loader = DataLoader(
        query_reader,
        batch_size=cfg.batch_size,
        collate_fn=collator
    )
    query_embeds, query_metadata = _encode_dataset(
        model, query_loader, "Encoding queries"
    )

    # zip() below would silently drop queries on a length mismatch.
    if len(query_metadata) != len(query_embeds):
        raise ValueError(
            f"Got {len(query_metadata)} query metadata entries for "
            f"{len(query_embeds)} query embeddings"
        )

    # --- Pick the selector once; the choice does not vary per query ---
    if cfg.use_polynomial_kernel:
        select = submodular_poly_kernel_select_gpu
        kernel_kwargs = {}
    else:
        select = submodular_kernel_select_gpu
        kernel_kwargs = {
            "beta": cfg.beta,                # regularization β (Eq. 11)
            "lengthscale": cfg.lengthscale,  # RBF kernel ℓ (Sec. 2.2)
        }
    num_ice = cfg.num_ice            # how many to pick
    lambd = cfg.lambd                # diversity trade-off λ
    sample_limit = cfg.run_for_n_samples  # falsy value means "no limit"

    # --- Submodular-selection loop ---
    results = []
    for idx, (q_embed, meta) in tqdm.tqdm(
            enumerate(zip(query_embeds, query_metadata)),
            total=len(query_embeds),
            desc="Selecting contexts"
    ):
        if sample_limit and idx >= sample_limit:
            break

        selected_idxs = _as_plain_list(
            select(
                demo_embeds=demo_embeds,
                test_embed=q_embed,
                k=num_ice,
                lambd=lambd,
                **kernel_kwargs
            )
        )

        entry = query_reader.dataset_wrapper[meta["id"]].copy()
        entry["ctxs"] = selected_idxs
        entry["ctxs_candidates"] = [[i] for i in selected_idxs]
        results.append(entry)

    # --- Write out JSON ---
    out_dir = os.path.dirname(cfg.output_file)
    if out_dir:
        # dirname is "" for a bare filename, and os.makedirs("") raises.
        os.makedirs(out_dir, exist_ok=True)
    with open(cfg.output_file, "w", encoding="utf-8") as fout:
        json.dump(results, fout, indent=2)

# Script entry point. main() is called with no arguments here; the
# @hydra.main decorator applied to it is what supplies `cfg`.
if __name__ == "__main__":
    main()
